{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QyCcF45zBQ3E"
      },
      "source": [
        "##### Copyright 2018 - 2020 The TensorFlow Authors. [Licensed under the Apache License, Version 2.0](#scrollTo=y_UVSRtBBsJk)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "CPII1rGR2rF9"
      },
      "outputs": [],
      "source": [
        "// #@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n",
        "// Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "// you may not use this file except in compliance with the License.\n",
        "// You may obtain a copy of the License at\n",
        "//\n",
        "// https://www.apache.org/licenses/LICENSE-2.0\n",
        "//\n",
        "// Unless required by applicable law or agreed to in writing, software\n",
        "// distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "// See the License for the specific language governing permissions and\n",
        "// limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zBH72IXMJ3JJ"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://www.tensorflow.org/swift/tutorials/model_training_walkthrough\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\">TensorFlow.org에서 보기</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/swift/blob/master/docs/site/tutorials/model_training_walkthrough.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">Google Colab에서 실행하기</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/swift/blob/master/docs/site/tutorials/model_training_walkthrough.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">View source on GitHub</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JtEZ1pCPn--z"
      },
      "source": [
        "# 모델 훈련 살펴보기"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LDrzLFXE8T1l"
      },
      "source": [
        "이 튜토리얼에서는 붓꽃을 종별로 분류하는 머신러닝 모델을 빌드하는 과정을 통하여 TensorFlow용 Swift를 소개합니다. 이 예제에서는 TensorFlow용 Swift를 사용하여 다음을 수행합니다.\n",
        "\n",
        "1. 모델을 빌드하고,\n",
        "2. 예제 데이터로 모델을 훈련하고,\n",
        "3. 모델을 사용하여 알려지지 않은 데이터에 대해 예측합니다.\n",
        "\n",
        "## TensorFlow 프로그래밍\n",
        "\n",
        "이 튜토리얼에서는 다음과 같은 상위 수준의 TensorFlow용 Swift 개념을 사용합니다.\n",
        "\n",
        "- Epochs API를 사용하여 데이터를 가져옵니다.\n",
        "- Swift 추상화를 사용하여 모델을 빌드합니다.\n",
        "- 순수한 Swift 라이브러리를 사용할 수 없을 때 Swift의 Python 상호 운용성을 사용하여 Python 라이브러리를 사용합니다.\n",
        "\n",
        "이 튜토리얼은 다수의 TensorFlow 프로그램과 유사하게 구성되어 있습니다.\n",
        "\n",
        "1. 데이터세트를 가져오고 구문 분석합니다.\n",
        "2. 모델 형식을 선택합니다.\n",
        "3. 모델을 훈련합니다.\n",
        "4. 모델의 효과를 평가합니다.\n",
        "5. 훈련된 모델을 사용하여 예측합니다."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yNr7H-AIoLOR"
      },
      "source": [
        "## 설치 프로그램"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1J3AuPBT9gyR"
      },
      "source": [
        "### 가져오기 구성하기\n",
        "\n",
        "TensorFlow 및 유용한 Python 모듈을 가져옵니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g4Wzg69bnwK2"
      },
      "outputs": [],
      "source": [
        "import TensorFlow\n",
        "import PythonKit"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5ms1o2W-DF1g"
      },
      "outputs": [],
      "source": [
        "// This cell is here to display the plots in a Jupyter Notebook.\n",
        "// Do not copy it into another environment.\n",
        "%include \"EnableIPythonDisplay.swift\"\n",
        "IPythonDisplay.shell.enable_matplotlib(\"inline\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "82TJnHsCY02t"
      },
      "outputs": [],
      "source": [
        "let plt = Python.import(\"matplotlib.pyplot\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-5p_tYTyz43Q"
      },
      "outputs": [],
      "source": [
        "import Foundation\n",
        "import FoundationNetworking\n",
        "func download(from sourceString: String, to destinationString: String) {\n",
        "    let source = URL(string: sourceString)!\n",
        "    let destination = URL(fileURLWithPath: destinationString)\n",
        "    let data = try! Data.init(contentsOf: source)\n",
        "    try! data.write(to: destination)\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Zx7wc0LuuxaJ"
      },
      "source": [
        "## 붓꽃 분류 문제\n",
        "\n",
        "식물학자가 되어 발견한 각 붓꽃을 자동으로 분류하는 방법을 찾는다고 상상해 보세요. 머신러닝은 꽃을 통계적으로 분류하는 다양한 알고리즘을 제공합니다. 예를 들어, 정교한 머신러닝 프로그램은 사진을 기반으로 꽃을 분류할 수 있습니다. 여기서는 그보다는 소박하게 [꽃받침](https://en.wikipedia.org/wiki/Sepal)과 [꽃잎](https://en.wikipedia.org/wiki/Petal)의 길이와 너비 측정치를 기준으로 붓꽃을 분류해 보겠습니다.\n",
        "\n",
        "붓꽃 속은 약 300종을 포함하지만 여기 프로그램에서는 다음 세 가지만 분류할 것입니다.\n",
        "\n",
        "- Iris setosa\n",
        "- Iris virginica\n",
        "- Iris versicolor\n",
        "\n",
        "<table>\n",
        "  <tr><td>     <img src=\"https://www.tensorflow.org/images/iris_three_species.jpg\" alt=\"Petal geometry compared for three iris species: Iris setosa, Iris virginica, and Iris versicolor\" class=\"\">   </td></tr>\n",
        "  <tr><td align=\"center\">     <b>Figure 1.</b> <a href=\"https://commons.wikimedia.org/w/index.php?curid=170298\">Iris setosa</a> (by <a href=\"https://commons.wikimedia.org/wiki/User:Radomil\">Radomil</a>, CC BY-SA 3.0), <a href=\"https://commons.wikimedia.org/w/index.php?curid=248095\">Iris versicolor</a>, (by <a href=\"https://commons.wikimedia.org/wiki/User:Dlanglois\">Dlanglois</a>, CC BY-SA 3.0), and <a href=\"https://www.flickr.com/photos/33397993@N05/3352169862\">Iris virginica</a> (by <a href=\"https://www.flickr.com/photos/33397993@N05\">Frank Mayfield</a>, CC BY-SA 2.0).<br>\n",
        "</td></tr>\n",
        "</table>\n",
        "\n",
        "다행히 누군가 이미 꽃받침과 꽃잎 측정치가 담긴 [120개의 붓꽃 데이터세트를](https://en.wikipedia.org/wiki/Iris_flower_data_set) 만들었습니다. 이것은 초보자 머신러닝 분류 문제에 널리 사용되는 클래식 데이터세트입니다."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3Px6KAg0Jowz"
      },
      "source": [
        "## 훈련 데이터세트 가져오기 및 구문 분석하기\n",
        "\n",
        "데이터세트 파일을 다운로드하고 이 Swift 프로그램에서 사용할 수있는 구조로 변환합니다.\n",
        "\n",
        "### 데이터세트 다운로드하기\n",
        "\n",
        "http://download.tensorflow.org/data/iris_training.csv에서 훈련 데이터세트 파일을 다운로드합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DKkgac4WO0mP"
      },
      "outputs": [],
      "source": [
        "let trainDataFilename = \"iris_training.csv\"\n",
        "download(from: \"http://download.tensorflow.org/data/iris_training.csv\", to: trainDataFilename)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qnX1-aLors4S"
      },
      "source": [
        "### 데이터 검사하기\n",
        "\n",
        "이 데이터세트 `iris_training.csv`는 쉼표로 구분된 값(CSV)으로 형식이 지정된 테이블 형식 데이터를 저장하는 일반 텍스트 파일입니다. 처음 5개 항목을 살펴보겠습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FQvb_JYdrpPm"
      },
      "outputs": [],
      "source": [
        "let f = Python.open(trainDataFilename)\n",
        "for _ in 0..<5 {\n",
        "    print(Python.next(f).strip())\n",
        "}\n",
        "f.close()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kQhzD6P-uBoq"
      },
      "source": [
        "데이터세트가 보이면 다음 사항을 확인하세요.\n",
        "\n",
        "1. 첫 번째 줄은 데이터세트에 대한 정보가 포함된 헤더입니다.\n",
        "\n",
        "- 총 120개의 예가 있습니다. 각 예에는 4개의 특성과 3개의 가능한 레이블 이름 중 하나가 있습니다.\n",
        "\n",
        "1. 후속 행은 한 줄에 하나의 *[예](https://developers.google.com/machine-learning/glossary/#example)*를 표시한 데이터 레코드이며 다음을 포함합니다.\n",
        "\n",
        "- 처음 네 개의 필드는 예제의 특성을 표시하는 *[특성](https://developers.google.com/machine-learning/glossary/#feature)*입니다. 여기에는 꽃 측정치를 나타내는 부동 숫자가 표시됩니다.\n",
        "- 마지막 열은 예측하려는 값을 표시하는 *[레이블](https://developers.google.com/machine-learning/glossary/#label)*입니다. 이 데이터세트의 경우, 꽃 이름에 해당하는 정수 값 0, 1 또는 2가 여기 표시됩니다.\n",
        "\n",
        "코드로 작성해 보겠습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9Edhevw7exl6"
      },
      "outputs": [],
      "source": [
        "let featureNames = [\"sepal_length\", \"sepal_width\", \"petal_length\", \"petal_width\"]\n",
        "let labelName = \"species\"\n",
        "let columnNames = featureNames + [labelName]\n",
        "\n",
        "print(\"Features: \\(featureNames)\")\n",
        "print(\"Label: \\(labelName)\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CCtwLoJhhDNc"
      },
      "source": [
        "각 레이블은 문자열 이름(예: 'setosa')과 연결되지만, 머신러닝은 일반적으로 숫자값에 의존합니다. 레이블 번호는 다음과 같은 명명된 표현에 매핑됩니다.\n",
        "\n",
        "- `0`: Iris setosa\n",
        "- `1`: Iris versicolor\n",
        "- `2`: Iris virginica\n",
        "\n",
        "특성 및 레이블에 대한 자세한 내용은 [머신러닝 단기 집중 과정의 ML 용어 섹션](https://developers.google.com/machine-learning/crash-course/framing/ml-terminology)을 참조하세요."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sVNlJlUOhkoX"
      },
      "outputs": [],
      "source": [
        "let classNames = [\"Iris setosa\", \"Iris versicolor\", \"Iris virginica\"]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dqPkQExM2Pwt"
      },
      "source": [
        "### Epochs API를 사용하여 데이터세트 만들기\n",
        "\n",
        "Swift for TensorFlow의 Epochs API는 데이터를 읽고 훈련에 사용되는 형식으로 변환하기 위한 상위 수준의 API입니다. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bBx_C6UWO0mc"
      },
      "outputs": [],
      "source": [
        "let batchSize = 32\n",
        "\n",
        "/// A batch of examples from the iris dataset.\n",
        "struct IrisBatch {\n",
        "    /// [batchSize, featureCount] tensor of features.\n",
        "    let features: Tensor<Float>\n",
        "\n",
        "    /// [batchSize] tensor of labels.\n",
        "    let labels: Tensor<Int32>\n",
        "}\n",
        "\n",
        "/// Conform `IrisBatch` to `Collatable` so that we can load it into a `TrainingEpoch`.\n",
        "extension IrisBatch: Collatable {\n",
        "    public init<BatchSamples: Collection>(collating samples: BatchSamples)\n",
        "        where BatchSamples.Element == Self {\n",
        "        /// `IrisBatch`es are collated by stacking their feature and label tensors\n",
        "        /// along the batch axis to produce a single feature and label tensor\n",
        "        features = Tensor<Float>(stacking: samples.map{$0.features})\n",
        "        labels = Tensor<Int32>(stacking: samples.map{$0.labels})\n",
        "    }\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SO6elT-kwwIK"
      },
      "source": [
        "다운로드한 데이터세트는 CSV 형식이므로 IrisBatch 객체 목록으로 데이터를 로드하는 함수를 작성해 보겠습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LwA21wYCguO5"
      },
      "outputs": [],
      "source": [
        "/// Initialize an `IrisBatch` dataset from a CSV file.\n",
        "func loadIrisDatasetFromCSV(\n",
        "        contentsOf: String, hasHeader: Bool, featureColumns: [Int], labelColumns: [Int]) -> [IrisBatch] {\n",
        "        let np = Python.import(\"numpy\")\n",
        "\n",
        "        let featuresNp = np.loadtxt(\n",
        "            contentsOf,\n",
        "            delimiter: \",\",\n",
        "            skiprows: hasHeader ? 1 : 0,\n",
        "            usecols: featureColumns,\n",
        "            dtype: Float.numpyScalarTypes.first!)\n",
        "        guard let featuresTensor = Tensor<Float>(numpy: featuresNp) else {\n",
        "            // This should never happen, because we construct featuresNp in such a\n",
        "            // way that it should be convertible to tensor.\n",
        "            fatalError(\"np.loadtxt result can't be converted to Tensor\")\n",
        "        }\n",
        "\n",
        "        let labelsNp = np.loadtxt(\n",
        "            contentsOf,\n",
        "            delimiter: \",\",\n",
        "            skiprows: hasHeader ? 1 : 0,\n",
        "            usecols: labelColumns,\n",
        "            dtype: Int32.numpyScalarTypes.first!)\n",
        "        guard let labelsTensor = Tensor<Int32>(numpy: labelsNp) else {\n",
        "            // This should never happen, because we construct labelsNp in such a\n",
        "            // way that it should be convertible to tensor.\n",
        "            fatalError(\"np.loadtxt result can't be converted to Tensor\")\n",
        "        }\n",
        "\n",
        "        return zip(featuresTensor.unstacked(), labelsTensor.unstacked()).map{IrisBatch(features: $0.0, labels: $0.1)}\n",
        "\n",
        "    }"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HbmOmUWneVGM"
      },
      "source": [
        "이제 CSV 로딩 함수를 사용하여 훈련 데이터세트를 로드하고 `TrainingEpochs` 객체를 생성할 수 있습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zFnMejfFYgSV"
      },
      "outputs": [],
      "source": [
        "let trainingDataset: [IrisBatch] = loadIrisDatasetFromCSV(contentsOf: trainDataFilename, \n",
        "                                                  hasHeader: true, \n",
        "                                                  featureColumns: [0, 1, 2, 3], \n",
        "                                                  labelColumns: [4])\n",
        "\n",
        "let trainingEpochs: TrainingEpochs = TrainingEpochs(samples: trainingDataset, batchSize: batchSize)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gB_RSn62c-3G"
      },
      "source": [
        "`TrainingEpochs` 객체는 무한한 epoch의 시퀀스입니다. 각 epoch에는 `IrisBatch`가 ​​포함됩니다. 첫 번째 epoch의 첫 번째 요소를 살펴보겠습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iDuG94H-C122"
      },
      "outputs": [],
      "source": [
        "let firstTrainEpoch = trainingEpochs.next()!\n",
        "let firstTrainBatch = firstTrainEpoch.first!.collated\n",
        "let firstTrainFeatures = firstTrainBatch.features\n",
        "let firstTrainLabels = firstTrainBatch.labels\n",
        "\n",
        "print(\"First batch of features: \\(firstTrainFeatures)\")\n",
        "print(\"firstTrainFeatures.shape: \\(firstTrainFeatures.shape)\")\n",
        "print(\"First batch of labels: \\(firstTrainLabels)\")\n",
        "print(\"firstTrainLabels.shape: \\(firstTrainLabels.shape)\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E63mArnQaAGz"
      },
      "source": [
        "첫 번째 `batchSize` 예제의 특성은 `firstTrainFeatures`로 함께 그룹화(또는 *일괄 처리*)되고 첫 번째 `batchSize` 예제의 레이블은 `firstTrainLabels`로 일괄 처리됩니다.\n",
        "\n",
        "Python의 matplotlib를 사용하여 배치에서 몇 가지 특성을 플롯하면 일부 클러스터를 볼 수 있습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "me5Wn-9FcyyO"
      },
      "outputs": [],
      "source": [
        "let firstTrainFeaturesTransposed = firstTrainFeatures.transposed()\n",
        "let petalLengths = firstTrainFeaturesTransposed[2].scalars\n",
        "let sepalLengths = firstTrainFeaturesTransposed[0].scalars\n",
        "\n",
        "plt.scatter(petalLengths, sepalLengths, c: firstTrainLabels.array.scalars)\n",
        "plt.xlabel(\"Petal length\")\n",
        "plt.ylabel(\"Sepal length\")\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LsaVrtNM3Tx5"
      },
      "source": [
        "## 모델 형식 선택하기\n",
        "\n",
        "### 왜 모델인가?\n",
        "\n",
        "*[모델](https://developers.google.com/machine-learning/crash-course/glossary#model)*은 특성과 레이블 간의 관계입니다. 붓꽃 분류 문제의 경우, 모델은 꽃받침과 꽃잎 측정치와 예측된 붓꽃 종 간의 관계를 정의합니다. 일부 간단한 모델은 몇 줄의 대수로 설명할 수 있지만, 복잡한 머신러닝 모델에는 요약하기 어려운 매개변수가 많습니다.\n",
        "\n",
        "머신러닝을 사용하지 *않고* 4가지 특성과 붓꽃 종 간의 관계를 확인할 수 있을까요? 즉, 기존 프로그래밍 기술(예: 여러 개의 조건문)을 사용하여 모델을 만들 수 있을까요? 특정 종에 대한 꽃잎과 꽃받침 측정치 간의 관계를 확인할 수 있을 만큼 충분히 오랫동안 데이터세트를 분석한 경우 가능할 수도 있습니다. 그러나 이것은 더 복잡한 데이터세트에서는 어렵거나 불가능할 수도 있습니다. 좋은 머신러닝 접근 방식이라면 *적절한 모델을 제시해 줍니다*. 적절한 머신러닝 모델 형식에 충분한 대표 예제를 제공하면 프로그램이 관계를 파악해 줍니다.\n",
        "\n",
        "### 모델 선택하기\n",
        "\n",
        "훈련할 모델의 종류를 선택해야 합니다. 많은 형식의 모델이 있으며 좋은 모델을 선택하려면 경험이 필요합니다. 이 튜토리얼에서는 신경망을 사용하여 붓꽃 분류 문제를 해결합니다. *[신경망](https://developers.google.com/machine-learning/glossary/#neural_network)*은 특성과 레이블 간의 복잡한 관계를 찾을 수 있으며, 하나 이상의 *[숨겨진 레이어](https://developers.google.com/machine-learning/glossary/#hidden_layer)*로 구성된 고도로 구조화된 그래프입니다. 각 숨겨진 레이어는 하나 이상의 *[신경](https://developers.google.com/machine-learning/glossary/#neuron)*으로 구성됩니다. 신경망에는 여러 범주가 있으며, 이 프로그램은 조밀하거나 *[완전히 연결된 신경망](https://developers.google.com/machine-learning/glossary/#fully_connected_layer)*을 사용합니다. 즉, 한 레이어의 신경은 이전 레이어의 *모든* 신경에서 입력 연결을 받습니다. 예를 들어, 그림 2는 입력 레이어, 2개의 숨겨진 레이어 및 출력 레이어로 구성된 조밀한 신경망을 보여줍니다.\n",
        "\n",
        "<table>\n",
        "  <tr><td>     <img src=\"https://www.tensorflow.org/images/custom_estimators/full_network.png\" alt=\"A diagram of the network architecture: Inputs, 2 hidden layers, and outputs\">\n",
        "</td></tr>\n",
        "  <tr><td align=\"center\">     <b>그림 2.</b> 특성, 숨겨진 레이어, 예측값으로 이루어진 신경망<br>\n",
        "</td></tr>\n",
        "</table>\n",
        "\n",
        "그림 2의 모델을 훈련하고 레이블이 지정되지 않은 예제를 제공하면, 이 꽃이 주어진 붓꽃 종일 가능성에 대한 3가지 예측값이 생성됩니다. 이 예측을 *[추론](https://developers.google.com/machine-learning/crash-course/glossary#inference)*이라고 합니다. 이 예에서 출력 예측값의 합계는 1.0입니다. 그림 2에서 이 예측은 <em>Iris setosa</em>의 경우 <code>0.02</code>, <em>Iris versicolor</em>의 경우 <code>0.95</code>, <em>Iris virginica</em>의 경우 <code>0.03</code>입니다. 즉, 모델은 95% 확률로 레이블이 지정되지 않은 예시 꽃이 <em>Iris versicolor</em>라고 예측합니다."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W23DIMVPQEBt"
      },
      "source": [
        "### TensorFlow Deep Learning 라이브러리용 Swift를 사용하여 모델 생성하기\n",
        "\n",
        "[TensorFlow Deep Learning 라이브러리용 Swift](https://github.com/tensorflow/swift-apis)를 사용하면 함께 와이어링하기 위한 기본 레이어와 규칙을 정의하여 모델을 쉽게 빌드하고 실험할 수 있습니다.\n",
        "\n",
        "모델은 [`Layer`](https://www.tensorflow.org/swift/api_docs/Protocols/Layer)를 준수하는 `struct`입니다. 즉, 입력 `Tensor`를 출력 `Tensor`에 매핑하는 [`callAsFunction(_:)`](https://www.tensorflow.org/swift/api_docs/Protocols/Layer#callasfunction_:) 메서드를 정의합니다. `callAsFunction(_:)` 메서드는 종종 하위 레이어를 통해 입력을 차례로 배열하기만 합니다. 3개의 [`Dense`](https://www.tensorflow.org/swift/api_docs/Structs/Dense) 하위 레이어를 통해 입력의 시퀀스를 생성하는 `IrisModel`을 정의해 보겠습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wr5A5WvthvZ0"
      },
      "outputs": [],
      "source": [
        "import TensorFlow\n",
        "\n",
        "let hiddenSize: Int = 10\n",
        "struct IrisModel: Layer {\n",
        "    var layer1 = Dense<Float>(inputSize: 4, outputSize: hiddenSize, activation: relu)\n",
        "    var layer2 = Dense<Float>(inputSize: hiddenSize, outputSize: hiddenSize, activation: relu)\n",
        "    var layer3 = Dense<Float>(inputSize: hiddenSize, outputSize: 3)\n",
        "    \n",
        "    @differentiable\n",
        "    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {\n",
        "        return input.sequenced(through: layer1, layer2, layer3)\n",
        "    }\n",
        "}\n",
        "\n",
        "var model = IrisModel()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fK0vrIRv_tcc"
      },
      "source": [
        "활성화 함수는 레이어에 있는 각 노드의 출력 형상을 결정합니다. 이러한 비선형성은 중요합니다. 비선형성이 없으면 모델은 단일 레이어와 동일하기 떄문입니다. 사용 가능한 활성화가 많이 있지만, [ReLU](https://www.tensorflow.org/swift/api_docs/Functions#relu_:)는 숨겨진 레이어에 일반적입니다.\n",
        "\n",
        "숨겨진 레이어와 신경의 이상적인 수는 문제와 데이터세트에 따라 다릅니다. 머신러닝의 여러 측면과 마찬가지로 신경망의 최상의 형태를 고르기 위해서는 지식과 실험이 모두 필요합니다. 경험상 숨겨진 레이어와 신경의 수를 늘리면 일반적으로 더 강력한 모델이 생성되며 이를 효과적으로 훈련하려면 더 많은 데이터가 필요합니다."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2wFKnhWCpDSS"
      },
      "source": [
        "### 모델 사용하기\n",
        "\n",
        "이 모델이 특성 배치에 대해 어떤 역할을 하는지 간단히 살펴보겠습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sKjJGIYzO0mr"
      },
      "outputs": [],
      "source": [
        "// Apply the model to a batch of features.\n",
        "let firstTrainPredictions = model(firstTrainFeatures)\n",
        "firstTrainPredictions[0..<5]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wxyXOhwVr5S3"
      },
      "source": [
        "여기에서 각 예제는 각 클래스에 대한 [로짓](https://developers.google.com/machine-learning/crash-course/glossary#logits)을 반환합니다.\n",
        "\n",
        "이러한 로짓을 각 클래스의 확률로 변환하려면 [softmax](https://developers.google.com/machine-learning/crash-course/glossary#softmax) 함수를 사용하세요."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_tRwHZmTNTX2"
      },
      "outputs": [],
      "source": [
        "softmax(firstTrainPredictions[0..<5])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uRZmchElo481"
      },
      "source": [
        "클래스에서 `argmax`를 사용하면 예측된 클래스 인덱스가 제공됩니다. 그러나 모델은 아직 훈련되지 않았으므로 좋은 예측이 아닙니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-Jzm_GoErz8B"
      },
      "outputs": [],
      "source": [
        "print(\"Prediction: \\(firstTrainPredictions.argmax(squeezingAxis: 1))\")\n",
        "print(\"    Labels: \\(firstTrainLabels)\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Vzq2E5J2QMtw"
      },
      "source": [
        "## 모델 훈련하기\n",
        "\n",
        "*[훈련하기](https://developers.google.com/machine-learning/crash-course/glossary#training)*는 모델이 점차 최적화될 때 또는 모델이 데이터세트를 *학습하는* 머신러닝 단계입니다. 이 단계의 목표는 훈련 데이터세트의 구조에 대해 충분히 학습하여 보이지 않는 데이터를 예측하는 것입니다. 훈련 데이터세트에 대해 *너무 많이* 배우면 예측이 관측한 데이터에 대해서만 작동하고 일반화할 수 없습니다. 이런 문제를 *[과대적합](https://developers.google.com/machine-learning/crash-course/glossary#overfitting)*이라고 하며, 이는 문제를 해결하는 방법을 이해하는 대신 답을 암기하는 것과 같습니다.\n",
        "\n",
        "붓꽃 분류 문제는 *[지도 머신러닝](https://developers.google.com/machine-learning/glossary/#supervised_machine_learning)*의 예입니다. 모델은 레이블이 포함된 예시로 훈련됩니다. *[비지도 머신러닝](https://developers.google.com/machine-learning/glossary/#unsupervised_machine_learning)*에서 예시에는 레이블이 포함되지 않습니다. 대신 모델은 일반적으로 특성 사이에서 패턴을 찾습니다."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RaKp8aEjKX6B"
      },
      "source": [
        "### 손실 함수 선택하기\n",
        "\n",
        "훈련 및 평가 단계 모두 모델의 *[손실](https://developers.google.com/machine-learning/crash-course/glossary#loss)*을 계산해야 합니다. 이것은 모델의 예측이 원하는 레이블에서 얼마나 떨어져 있는지, 즉 모델의 성능이 얼마나 나쁜지를 측정합니다. 이 값을 최소화하거나 최적화하려고 합니다.\n",
        "\n",
        "이 모델은 모델의 클래스 확률 예측과 원하는 레이블을 사용하고 예제 전체에 걸쳐 평균 손실을 반환하는 [`softmaxCrossEntropy(logits:labels:)`](https://www.tensorflow.org/swift/api_docs/Functions#/s:10TensorFlow19softmaxCrossEntropy6logits6labelsAA0A0VyxGAG_AFys5Int32VGtAA0aB13FloatingPointRzlF) 함수를 사용하여 손실을 계산합니다.\n",
        "\n",
        "현재 훈련되지 않은 모델의 손실을 계산해 보겠습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tMAT4DcMPwI-"
      },
      "outputs": [],
      "source": [
        "let untrainedLogits = model(firstTrainFeatures)\n",
        "let untrainedLoss = softmaxCrossEntropy(logits: untrainedLogits, labels: firstTrainLabels)\n",
        "print(\"Loss test: \\(untrainedLoss)\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lOxFimtlKruu"
      },
      "source": [
        "### 옵티마이저 만들기\n",
        "\n",
        "*[옵티마이저](https://developers.google.com/machine-learning/crash-course/glossary#optimizer)*는 계산된 그래디언트를 모델의 변수에 적용하여 `loss` 함수를 최소화합니다. 손실 함수를 곡면으로 생각해 보세요(그림 3 참조). 곡면을 걸어 다니면서 가장 낮은 지점을 찾으려고 하는 것입니다. 그래디언트는 가장 가파른 상승 방향을 가리키므로 반대 방향으로 이동하여 경사를 내려갑니다. 각 배치의 손실과 그래디언트를 반복적으로 계산하여 훈련 중에 모델을 조정합니다. 점차적으로 모델은 손실을 최소화하기 위해 가중치와 바이어스의 최상의 조합을 찾습니다. 손실이 낮을수록 모델의 예측값이 더 좋습니다.\n",
        "\n",
        "<table>\n",
        "  <tr><td>     <img src=\"https://cs231n.github.io/assets/nn3/opt1.gif\" width=\"70%\" alt=\"Optimization algorithms visualized over time in 3D space.\">\n",
        "</td></tr>\n",
        "  <tr><td align=\"center\">     <b>그림 3.</b> 3D 공간에서 시간에 걸쳐 시각화한 최적화 알고리즘<br>(출처: <a href=\"http://cs231n.github.io/neural-networks-3/\">Stanford class CS231n</a>, MIT License, 이미지 제공: <a href=\"https://twitter.com/alecrad\">Alec Radford</a>)</td></tr>\n",
        "</table>\n",
        "\n",
        "TensorFlow용 Swift에는 훈련에 사용할 수 있는 많은 [최적화 알고리즘](https://github.com/rxwei/DeepLearning/blob/master/Sources/DeepLearning/Optimizer.swift)이 있습니다. 이 모델은 *[확률적 경사 하강](https://developers.google.com/machine-learning/crash-course/glossary#gradient_descent)*(SGD) 알고리즘을 구현하는 SGD 옵티마이저를 사용합니다. `learningRate`는 경사 아래로 반복할 때마다 사용할 단계의 크기를 설정하는데 이것은 더 나은 결과를 얻기 위해 일반적으로 조정하는 *하이퍼 매개변수*입니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8xxi2NNGKwG_"
      },
      "outputs": [],
      "source": [
        "let optimizer = SGD(for: model, learningRate: 0.01)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pJVRZ0hP52ZB"
      },
      "source": [
        "`optimizer`를 사용하여 단일 경사 하강 단계를 수행해 보겠습니다. 먼저 모델에 대한 손실의 그래디언트를 계산합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rxRNTFVe56RG"
      },
      "outputs": [],
      "source": [
        "let (loss, grads) = valueWithGradient(at: model) { model -> Tensor<Float> in\n",
        "    let logits = model(firstTrainFeatures)\n",
        "    return softmaxCrossEntropy(logits: logits, labels: firstTrainLabels)\n",
        "}\n",
        "print(\"Current loss: \\(loss)\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5B27cIT0O0nE"
      },
      "source": [
        "다음으로 방금 계산한 그래디언트를 옵티마이저에 전달하여 그에 따라 모델의 미분 변수를 업데이트합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "icyvh-o6O0nF"
      },
      "outputs": [],
      "source": [
        "optimizer.update(&model, along: grads)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nhpgM7UpO0nG"
      },
      "source": [
        "손실을 다시 계산하면 경사 하강 단계가 (일반적으로) 손실을 줄이기 때문에 더 작아져야 합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aw0OzyojAa39"
      },
      "outputs": [],
      "source": [
        "let logitsAfterOneStep = model(firstTrainFeatures)\n",
        "let lossAfterOneStep = softmaxCrossEntropy(logits: logitsAfterOneStep, labels: firstTrainLabels)\n",
        "print(\"Next loss: \\(lossAfterOneStep)\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7Y2VSELvwAvW"
      },
      "source": [
        "### 훈련 루프\n",
        "\n",
        "여기까지 모두 마쳤다면 모델을 훈련할 준비가 되었습니다! 훈련 루프는 더 나은 예측을 할 수 있도록 데이터세트 예제를 모델에 제공합니다. 다음 코드 블록은 이러한 훈련 단계를 설정합니다.\n",
        "\n",
        "1. 각 *epoch*를 반복합니다. epoch는 데이터세트를 한 번 통과하는 것을 의미합니다.\n",
        "2. 한 epoch 내에서 훈련 epoch 내의 각 배치를 반복합니다.\n",
        "3. 배치를 정렬하고 *특성* (`x`)과 *레이블* (`y`)을 가져옵니다.\n",
        "4. 정렬된 배치의 특성을 사용하여 예측을 수행하고 레이블과 비교합니다. 예측의 부정확성을 측정하고 이를 사용하여 모델의 손실 및 그래디언트를 계산합니다.\n",
        "5. 경사 하강을 사용하여 모델의 변수를 업데이트합니다.\n",
        "6. 시각화를 위해 몇 가지 통계를 추적합니다.\n",
        "7. 각 epoch에 대해 반복합니다.\n",
        "\n",
        "`epochCount` 변수는 데이터세트 모음을 반복하는 횟수입니다. 반 직관적으로 모델을 더 오래 훈련한다고 해서 더 나은 모델이 보장되는 것은 아닙니다. `epochCount`는 조정할 수 있는 *[하이퍼 매개변수](https://developers.google.com/machine-learning/glossary/#hyperparameter)*입니다. 올바른 숫자를 선택하려면 일반적으로 경험과 실험이 모두 필요합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AIgulGRUhpto"
      },
      "outputs": [],
      "source": [
        "let epochCount = 500\n",
        "var trainAccuracyResults: [Float] = []\n",
        "var trainLossResults: [Float] = []"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "066kVZQFO0nL"
      },
      "outputs": [],
      "source": [
        "func accuracy(predictions: Tensor<Int32>, truths: Tensor<Int32>) -> Float {\n",
        "    return Tensor<Float>(predictions .== truths).mean().scalarized()\n",
        "}\n",
        "\n",
        "for (epochIndex, epoch) in trainingEpochs.prefix(epochCount).enumerated() {\n",
        "    var epochLoss: Float = 0\n",
        "    var epochAccuracy: Float = 0\n",
        "    var batchCount: Int = 0\n",
        "    for batchSamples in epoch {\n",
        "        let batch = batchSamples.collated\n",
        "        let (loss, grad) = valueWithGradient(at: model) { (model: IrisModel) -> Tensor<Float> in\n",
        "            let logits = model(batch.features)\n",
        "            return softmaxCrossEntropy(logits: logits, labels: batch.labels)\n",
        "        }\n",
        "        optimizer.update(&model, along: grad)\n",
        "        \n",
        "        let logits = model(batch.features)\n",
        "        epochAccuracy += accuracy(predictions: logits.argmax(squeezingAxis: 1), truths: batch.labels)\n",
        "        epochLoss += loss.scalarized()\n",
        "        batchCount += 1\n",
        "    }\n",
        "    epochAccuracy /= Float(batchCount)\n",
        "    epochLoss /= Float(batchCount)\n",
        "    trainAccuracyResults.append(epochAccuracy)\n",
        "    trainLossResults.append(epochLoss)\n",
        "    if epochIndex % 50 == 0 {\n",
        "        print(\"Epoch \\(epochIndex): Loss: \\(epochLoss), Accuracy: \\(epochAccuracy)\")\n",
        "    }\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2FQHVUnm_rjw"
      },
      "source": [
        "### 시간 경과에 따른 손실 함수 시각화하기"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j3wdbmtLVTyr"
      },
      "source": [
        "모델의 훈련 진행 상황을 출력하는 것도 유용하지만, 진행 상황을 시각적으로 보는 것이 *더* 도움이 되는 경우가 많습니다. Python의 `matplotlib` 모듈을 사용하여 기본 차트를 만들 수 있습니다.\n",
        "\n",
        "이러한 차트를 해석하려면 어느 정도의 경험이 필요하지만, *궁극적인 목표는 손실*이 감소하고 *정확성*이 증가하는 것입니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "agjvNd2iUGFn"
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize: [12, 8])\n",
        "\n",
        "let accuracyAxes = plt.subplot(2, 1, 1)\n",
        "accuracyAxes.set_ylabel(\"Accuracy\")\n",
        "accuracyAxes.plot(trainAccuracyResults)\n",
        "\n",
        "let lossAxes = plt.subplot(2, 1, 2)\n",
        "lossAxes.set_ylabel(\"Loss\")\n",
        "lossAxes.set_xlabel(\"Epoch\")\n",
        "lossAxes.plot(trainLossResults)\n",
        "\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "axA6WuGVO0nR"
      },
      "source": [
        "그래프의 y 축이 0부터 시작하지 않는 점을 유의하세요."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Zg8GoMZhLpGH"
      },
      "source": [
        "## 모델의 효과 평가하기\n",
        "\n",
        "이제 모델이 훈련되었으므로 성능에 대한 통계를 얻을 수 있습니다.\n",
        "\n",
        "*평가*는 모델이 얼마나 효과적으로 예측을 수행하는지 알아보는 것을 의미합니다. 붓꽃 분류에서 모델의 효과를 확인하려면 꽃받침과 꽃잎 측정치를 모델에 전달하고 모델이 붓꽃 종을 예측하도록 요청합니다. 그런 다음 모델의 예측을 실제 레이블과 비교합니다. 예를 들어, 입력 예제의 절반에서 올바른 종을 선택한 모델의 *[정확성](https://developers.google.com/machine-learning/glossary/#accuracy)*은 `0.5`입니다. 그림 4는 약간 더 효과적인 모델을 보여줍니다. 5개 예측 중 4개는 80% 정확성으로 정확합니다.\n",
        "\n",
        "<table cellpadding=\"8\" border=\"0\">\n",
        "  <colgroup>\n",
        "    <col span=\"4\">\n",
        "    <col span=\"1\" bgcolor=\"lightblue\">\n",
        "    <col span=\"1\" bgcolor=\"lightgreen\">\n",
        "  </colgroup>\n",
        "  <tr bgcolor=\"lightgray\">\n",
        "    <th colspan=\"4\">예시 특성</th>\n",
        "    <th colspan=\"1\">레이블</th>\n",
        "    <th colspan=\"1\">모델 예측</th>\n",
        "  </tr>\n",
        "  <tr>\n",
        "    <td>5.9</td>\n",
        "<td>3.0</td>\n",
        "<td>4.3</td>\n",
        "<td>1.5</td>\n",
        "<td align=\"center\">1</td>\n",
        "<td align=\"center\">1</td>\n",
        "  </tr>\n",
        "  <tr>\n",
        "    <td>6.9</td>\n",
        "<td>3.1</td>\n",
        "<td>5.4</td>\n",
        "<td>2.1</td>\n",
        "<td align=\"center\">2</td>\n",
        "<td align=\"center\">2</td>\n",
        "  </tr>\n",
        "  <tr>\n",
        "    <td>5.1</td>\n",
        "<td>3.3</td>\n",
        "<td>1.7</td>\n",
        "<td>0.5</td>\n",
        "<td align=\"center\">0</td>\n",
        "<td align=\"center\">0</td>\n",
        "  </tr>\n",
        "  <tr>\n",
        "    <td>6.0</td> <td>3.4</td> <td>4.5</td> <td>1.6</td> <td align=\"center\">1</td>\n",
        "<td align=\"center\" bgcolor=\"red\">2</td>\n",
        "  </tr>\n",
        "  <tr>\n",
        "    <td>5.5</td>\n",
        "<td>2.5</td>\n",
        "<td>4.0</td>\n",
        "<td>1.3</td>\n",
        "<td align=\"center\">1</td>\n",
        "<td align=\"center\">1</td>\n",
        "  </tr>\n",
        "  <tr><td align=\"center\" colspan=\"6\">     <b>그림 4.</b> 정확성 80%의 붓꽃 분류기<br>\n",
        "</td></tr>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z-EvK7hGL0d8"
      },
      "source": [
        "### 테스트 데이터세트 설정하기\n",
        "\n",
        "모델 평가는 모델 훈련과 유사합니다. 가장 큰 차이점은 예제가 훈련 세트가 아닌 별도의 *[테스트 세트](https://developers.google.com/machine-learning/crash-course/glossary#test_set)*에서 나온다는 것입니다. 모델의 효과를 공정하게 평가하려면 모델을 평가하는 데 사용되는 예가 모델 훈련에 사용된 예와 달라야 합니다.\n",
        "\n",
        "테스트 데이터세트의 설정은 훈련 데이터세트의 설정과 유사합니다. http://download.tensorflow.org/data/iris_test.csv에서 테스트 세트를 다운로드합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SRMWCu30bnxH"
      },
      "outputs": [],
      "source": [
        "let testDataFilename = \"iris_test.csv\"\n",
        "download(from: \"http://download.tensorflow.org/data/iris_test.csv\", to: testDataFilename)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jEPPL6FUO0nV"
      },
      "source": [
        "이제 `IrisBatch` 배열에 로드합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "w6SCt95HO0nW"
      },
      "outputs": [],
      "source": [
        "let testDataset = loadIrisDatasetFromCSV(\n",
        "    contentsOfCSVFile: testDataFilename, hasHeader: true,\n",
        "    featureColumns: [0, 1, 2, 3], labelColumns: [4]).inBatches(of: batchSize)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HFuOKXJdMAdm"
      },
      "source": [
        "### 테스트 데이터세트에서 모델 평가하기\n",
        "\n",
        "훈련 단계와 달리 모델은 테스트 데이터의 단일 [epoch](https://developers.google.com/machine-learning/glossary/#epoch)만 평가합니다. 다음 코드 셀에서 테스트 세트의 각 예제를 반복하고 모델의 예측값을 실제 레이블과 비교합니다. 이것은 전체 테스트 세트에서 모델의 정확성을 측정하는 데 사용됩니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Tj4Rs8gwO0nY"
      },
      "outputs": [],
      "source": [
        "// NOTE: Only a single batch will run in the loop since the batchSize we're using is larger than the test set size\n",
        "for batchSamples in testDataset {\n",
        "    let batch = batchSamples.collated\n",
        "    let logits = model(batch.features)\n",
        "    let predictions = logits.argmax(squeezingAxis: 1)\n",
        "    print(\"Test batch accuracy: \\(accuracy(predictions: predictions, truths: batch.labels))\")\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HcKEZMtCOeK-"
      },
      "source": [
        "예를 들어, 첫 번째 배치에서 모델이 일반적으로 올바른 것을 확인할 수 있습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uNwt2eMeOane"
      },
      "outputs": [],
      "source": [
        "let firstTestBatch = testDataset.first!.collated\n",
        "let firstTestBatchLogits = model(firstTestBatch.features)\n",
        "let firstTestBatchPredictions = firstTestBatchLogits.argmax(squeezingAxis: 1)\n",
        "\n",
        "print(firstTestBatchPredictions)\n",
        "print(firstTestBatch.labels)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7Li2r1tYvW7S"
      },
      "source": [
        "## 훈련된 모델을 사용하여 예측하기\n",
        "\n",
        "모델을 훈련하고 이 모델이 붓꽃 종을 분류하는 데 훌륭하지만 완벽하지는 않다는 것을 증명했습니다. 이제 훈련된 모델을 사용하여 [레이블이 없는 예](https://developers.google.com/machine-learning/glossary/#unlabeled_example)에 대한 예측을 수행해 보겠습니다. 즉, 특성은 포함하지만 레이블은 포함하지 않는 예입니다.\n",
        "\n",
        "실제로 라벨이 지정되지 않은 예는 앱, CSV 파일, 데이터 피드 등 다양한 소스에서 제공될 수 있습니다. 지금은 레이블을 예측하기 위해 레이블이 없는 3가지 예제를 수동으로 제공할 것입니다. 레이블 번호는 다음과 같이 명명된 표현에 매핑됩니다.\n",
        "\n",
        "- `0`: Iris setosa\n",
        "- `1`: Iris versicolor\n",
        "- `2`: Iris virginica"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MTYOZr27O0ne"
      },
      "outputs": [],
      "source": [
        "let unlabeledDataset: Tensor<Float> =\n",
        "    [[5.1, 3.3, 1.7, 0.5],\n",
        "     [5.9, 3.0, 4.2, 1.5],\n",
        "     [6.9, 3.1, 5.4, 2.1]]\n",
        "\n",
        "let unlabeledDatasetPredictions = model(unlabeledDataset)\n",
        "\n",
        "for i in 0..<unlabeledDatasetPredictions.shape[0] {\n",
        "    let logits = unlabeledDatasetPredictions[i]\n",
        "    let classIdx = logits.argmax().scalar!\n",
        "    print(\"Example \\(i) prediction: \\(classNames[Int(classIdx)]) (\\(softmax(logits)))\")\n",
        "}"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "model_training_walkthrough.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Swift",
      "name": "swift"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
